#!/usr/bin/env python3
"""
Top-level driver for the loop-fluctuation correlation simulation.
[… docstring unchanged …]
"""

import argparse
import os
import sys
import yaml
import numpy as np

sys.path.append(os.path.dirname(__file__))  # ensure scripts/ is importable

from scripts.generate_flip_counts import generate_flip_counts
from scripts.sample_gauge_fields import sample_gauge_fields
from scripts.run_correlation import run_correlation
from scripts.generate_report import generate_report


def load_config(path: str) -> dict:
    """Load a YAML configuration file and return its top-level mapping.

    ``yaml.safe_load`` is used so the file cannot instantiate arbitrary
    Python objects.

    Args:
        path: Filesystem path of the YAML file (read as UTF-8).

    Returns:
        The parsed top-level mapping; an empty file yields ``{}`` rather
        than ``None`` so callers can use ``in`` / ``.get`` unconditionally.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty (or comment-only) document.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at top level, got {type(data).__name__}"
        )
    return data


def _resolve_sim_section(cfg: dict, lattice_override) -> tuple:
    """Return ``(sim_cfg, gauge_groups)`` from either supported config layout.

    Current layout: parameters under ``loop_fluctuation``; gauge groups are
    the keys of the top-level ``kernel_paths`` mapping.
    Legacy layout: parameters under ``parameters``; gauge groups come from
    its ``gauge_groups`` list (default ``["U1", "SU2", "SU3"]``).

    When ``lattice_override`` is not None it replaces ``lattice_size`` in the
    returned section (the section dict is updated in place).
    """
    if "loop_fluctuation" in cfg:
        # An empty YAML section parses as None; treat it like an empty mapping.
        sim_cfg = cfg["loop_fluctuation"] or {}
        gauge_groups = list(cfg.get("kernel_paths") or {})
    else:  # fallback legacy structure
        sim_cfg = cfg.get("parameters") or {}
        gauge_groups = sim_cfg.get("gauge_groups", ["U1", "SU2", "SU3"])

    if lattice_override is not None:
        sim_cfg["lattice_size"] = lattice_override
    return sim_cfg, gauge_groups


def main(argv=None) -> None:
    """Run the pipeline: flip counts -> gauge fields -> correlations -> report.

    Args:
        argv: Optional argument list for ``argparse``. ``None`` (the default)
            means ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.

    Exits through ``argparse`` (usage message, status 2) when the config's
    top level is not a mapping, when no lattice size is configured, or when
    ``loop_sizes`` is empty, instead of running with invalid parameters.
    """
    parser = argparse.ArgumentParser(description="Run the loop-fluctuation simulation pipeline")
    parser.add_argument("--config", required=True, help="YAML configuration file")
    parser.add_argument("--output-dir", required=True, help="Directory for outputs")
    parser.add_argument("--lattice-size", type=int, help="Override loop_fluctuation.lattice_size from config")
    args = parser.parse_args(argv)

    # An empty YAML file parses as None; treat it as an empty configuration.
    cfg = load_config(args.config) or {}
    if not isinstance(cfg, dict):
        parser.error(f"{args.config}: top-level YAML must be a mapping")

    # ----- 1. pull sim-specific params ------------------------------------------------
    sim_cfg, gauge_groups = _resolve_sim_section(cfg, args.lattice_size)

    lattice_size   = sim_cfg.get("lattice_size")
    context_depth  = sim_cfg.get("context_depth", sim_cfg.get("N", 2))
    steps_per_link = sim_cfg.get("steps_per_link", 1000)
    seed           = sim_cfg.get("seed")
    # Loop sizes and the bootstrap count are read from the top level of the YAML.
    loop_sizes        = cfg.get("loop_sizes", [1, 2, 3, 4])
    bootstrap_samples = cfg.get("bootstrap", 200)

    # Fail fast: otherwise None reaches the generators and output names such as
    # "flip_counts_LNone.npy" are produced.
    if lattice_size is None:
        parser.error("no lattice size given: set lattice_size in the config or pass --lattice-size")
    # With no loop sizes nothing is written to the CSV, and the report step would
    # read a missing (or possibly stale, from an earlier run) correlation_full.csv.
    if not loop_sizes:
        parser.error("config 'loop_sizes' must be a non-empty list")

    os.makedirs(args.output_dir, exist_ok=True)

    # ----- 2. flip counts ------------------------------------------------------------
    fc = generate_flip_counts(
        lattice_size=lattice_size,
        seed=seed,
        context_depth=context_depth,
        steps_per_link=steps_per_link,
    )
    # Write flip counts to a per-L file (flattened, no extra subfolders)
    fc_filename = f"flip_counts_L{lattice_size}.npy"
    fc_path = os.path.join(args.output_dir, fc_filename)
    np.save(fc_path, fc)
    print(f"[DEBUG] Saved flip counts for L={lattice_size} → {fc_path}")

    # ----- 3. sample gauge fields ----------------------------------------------------
    gauge_paths: dict[str, str] = {}
    for g in gauge_groups:
        g_path = os.path.join(args.output_dir, f"gauge_{g}.npy")
        sample_gauge_fields(
            flip_counts=fc,
            gauge_group=g,
            kernel_path=None,          # kernels disabled for simplicity
            output_path=g_path,
            lattice_size=lattice_size,
            trials=1,
        )
        gauge_paths[g] = g_path

    # ----- 4. correlation analysis ---------------------------------------------------
    corr_csv = os.path.join(args.output_dir, "correlation_full.csv")
    for i, L in enumerate(loop_sizes):
        run_correlation(
            flip_counts_path=fc_path,
            gauge_paths=gauge_paths,
            output_csv=corr_csv,
            loop_sizes=[L],
            bootstrap_samples=bootstrap_samples,
            # First call presumably (re)creates the CSV; later calls append rows.
            append=i > 0,
        )

    # ----- 5. summary report ---------------------------------------------------------
    report_md = os.path.join(args.output_dir, "correlation_report.md")
    generate_report(
        csv_path=corr_csv,
        report_md=report_md,
        plot_dir=args.output_dir,
    )


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
